#!/usr/bin/python
#coding=utf-8
import rospy
import cv2
from cv_bridge import CvBridge
import colorsys
import os
import time
import numpy as np
from keras import backend as K
from keras.applications.imagenet_utils import preprocess_input
from PIL import Image, ImageDraw, ImageFont
from nets.ssd import SSD300
from utils.utils import BBoxUtility, letterbox_image, ssd_correct_boxes
from sensor_msgs.msg import Image as ros_img
from ros_openvino.msg import Object 
# Single shared converter between sensor_msgs/Image messages and OpenCV arrays.
cvb = CvBridge()

#--------------------------------------------#
#   使用自己训练好的模型预测需要修改2个参数
#   model_path和classes_path都需要修改！
#   如果出现shape不匹配
#   一定要注意训练时的NUM_CLASSES、
#   model_path和classes_path参数的修改
#--------------------------------------------#
class SSD(object):
    """ROS node that runs a Keras SSD300 detector on incoming camera frames.

    Constructing an instance loads the model, creates the ROS node, the
    publishers and the camera subscriber, and then blocks in
    ``detect_image`` until ROS shuts down.

    Keyword arguments passed to the constructor override the entries of
    ``_defaults`` (and the ``_font_path`` / ``_box_margin`` class
    attributes), e.g. ``SSD(confidence=0.5)``.
    """
    _defaults = {
        "model_path"        : 'model_data/wrc_model.h5',
        "classes_path"      : 'model_data/wrc_classes.txt',
        "input_shape"       : (300, 300, 3),
        "confidence"        : 0.3,
        "nms_iou"           : 0.45,
        'anchors_size'      : [30,60,111,162,213,264,315],
        #---------------------------------------------------------------------#
        #   Whether to resize the input with letterbox_image (aspect-ratio
        #   preserving, padded with grey bars). In repeated tests a plain
        #   resize gave better results, so it is off by default.
        #---------------------------------------------------------------------#
        "letterbox_image"   : False,
    }
    # TrueType font used for the box labels drawn on the output image.
    _font_path = 'model_data/simhei.ttf'
    # Pixels added on every side of a detected box before it is drawn/published.
    _box_margin = 5

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``.

        For unknown names an error *string* is returned rather than an
        exception being raised (kept for backward compatibility).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the SSD node
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Load the model, wire up ROS I/O and run the detection loop.

        Blocks until ROS shuts down (see ``detect_image``).
        """
        self.__dict__.update(self._defaults)
        # Caller overrides; **kwargs used to be accepted but silently ignored.
        self.__dict__.update(kwargs)
        self.class_names = self._get_class()
        self.sess = K.get_session()
        self.generate()
        self.bbox_util = BBoxUtility(self.num_classes, nms_thresh=self.nms_iou)
        rospy.init_node('object_detection_wrc2021', anonymous=False)
        self.object_result_pub = rospy.Publisher("object_detection/results", Object, queue_size=1)
        self.object_img_pub = rospy.Publisher("object_detection/output_image", ros_img, queue_size=10)
        self.object_dectction_msg = Object()
        # Latest camera frame as a BGR array; None until the first message
        # arrives (the loop used to run inference on a dummy zeros image).
        self.img_input = None
        rospy.Subscriber("camera/rgb/image_raw", ros_img, self.read_img)
        self.detect_image()

    #---------------------------------------------------#
    #   Load the class names
    #---------------------------------------------------#
    def _get_class(self):
        """Read the class-name file, one name per line, whitespace-stripped."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        return [c.strip() for c in class_names]

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Build the SSD300 network, load its weights and pick box colours."""
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # One slot more than the named classes: detection labels are 1-based
        # (label k maps to class_names[k - 1]); slot 0 is presumably background.
        self.num_classes = len(self.class_names) + 1

        self.ssd_model = SSD300(self.input_shape, self.num_classes, anchors_size=self.anchors_size)
        # Use the expanded path (a leading "~" used to be ignored here).
        self.ssd_model.load_weights(model_path, by_name=True)

        print('{} model, anchors, and classes loaded.'.format(model_path))

        self.colors = self._make_colors(len(self.class_names))

    @staticmethod
    def _make_colors(n):
        """Return ``n`` distinct RGB tuples (0-255) spread evenly over the hue wheel."""
        # float() so the hue is not truncated to 0 by Python 2 integer division.
        hsv_tuples = [(float(x) / n, 1., 1.) for x in range(n)]
        return [tuple(int(v * 255) for v in colorsys.hsv_to_rgb(*hsv)) for hsv in hsv_tuples]

    #---------------------------------------------------#
    #   Camera callback
    #---------------------------------------------------#
    def read_img(self, image):
        """Subscriber callback: store the latest camera frame as a BGR array."""
        # Request bgr8 explicitly: the detection loop treats the frame as BGR,
        # which "passthrough" only guaranteed for bgr8 cameras.
        self.img_input = cvb.imgmsg_to_cv2(image, desired_encoding="bgr8")

    def detect_image(self):
        """Run detection on each new camera frame until ROS shuts down.

        For every frame not yet processed this publishes one ``Object``
        message per detection on ``object_detection/results`` and the
        annotated frame (``bgr8``, with an FPS overlay) on
        ``object_detection/output_image``.
        """
        fps = 0.0
        last_frame = None
        while not rospy.is_shutdown():
            frame_bgr = self.img_input
            # Nothing received yet, or this frame was already handled: don't
            # run the network again and republish duplicate results.
            if frame_bgr is None or frame_bgr is last_frame:
                time.sleep(0.001)
                continue
            last_frame = frame_bgr
            t1 = time.time()

            # The network is fed RGB; PIL is used for resizing and drawing.
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(np.uint8(frame_rgb)).convert('RGB')

            results = self._predict(image)
            if len(results) > 0 and len(results[0]) > 0:
                self._publish_detections(image, results[0])

            # Exponential moving average (weight 1/2) of the frame rate. fps is
            # kept across frames; resetting it each frame halved the value.
            elapsed = max(time.time() - t1, 1e-6)
            fps = (fps + (1. / elapsed)) / 2
            frame = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            # frame is BGR, so publish it as bgr8 (the labels used to be swapped).
            self.object_img_pub.publish(cvb.cv2_to_imgmsg(frame, "bgr8"))

    def _predict(self, image):
        """Resize/normalise a PIL RGB image, run the network and decode it.

        Returns the output of ``BBoxUtility.detection_out``; this class reads
        ``results[0]`` as rows ``[label, confidence, xmin, ymin, xmax, ymax]``.
        """
        size = (self.input_shape[1], self.input_shape[0])
        if self.letterbox_image:
            # Pad with grey bars for an aspect-preserving resize.
            crop_img = np.array(letterbox_image(image, size))
        else:
            crop_img = image.resize(size, Image.BICUBIC)
        photo = np.array(crop_img, dtype=np.float64)
        photo = preprocess_input(np.reshape(photo, [1, self.input_shape[0], self.input_shape[1], 3]))
        preds = self.ssd_model.predict(photo)
        return self.bbox_util.detection_out(preds, confidence_threshold=self.confidence)

    def _decode_boxes(self, result, image_shape):
        """Keep detections with confidence >= ``self.confidence`` and map boxes to pixels.

        Parameters
        ----------
        result : np.ndarray
            Rows ``[label, confidence, xmin, ymin, xmax, ymax]``; without
            letterboxing the coordinates are multiplied by the image size, so
            they are presumably normalised to [0, 1].
        image_shape : array-like
            ``(height, width)`` of the original image.

        Returns
        -------
        tuple
            ``(labels, confidences, boxes)``: ``labels`` is a list of floats,
            ``confidences`` a 1-D array, and ``boxes`` has one
            ``[top, left, bottom, right]`` row per kept detection in original
            image pixels (not yet rounded or clipped).
        """
        result = np.asarray(result)
        kept = result[result[:, 1] >= self.confidence]
        labels = kept[:, 0].tolist()
        confidences = kept[:, 1]
        # Each column as an (N, 1) array.
        xmin, ymin, xmax, ymax = np.split(kept[:, 2:6], 4, axis=1)
        if self.letterbox_image:
            # Undo the grey-bar padding added by letterbox_image.
            boxes = ssd_correct_boxes(ymin, xmin, ymax, xmax,
                                      np.array([self.input_shape[0], self.input_shape[1]]),
                                      image_shape)
        else:
            boxes = np.concatenate([ymin * image_shape[0], xmin * image_shape[1],
                                    ymax * image_shape[0], xmax * image_shape[1]], axis=-1)
        return labels, confidences, boxes

    @staticmethod
    def _clip_box(box, image_h, image_w, margin=5):
        """Grow ``box`` by ``margin`` px per side, round it and clip it to the image.

        ``box`` is ``(top, left, bottom, right)`` in float pixels; returns the
        same order as Python ints.
        """
        top, left, bottom, right = box
        top = max(0, int(np.floor(top - margin + 0.5)))
        left = max(0, int(np.floor(left - margin + 0.5)))
        bottom = min(image_h, int(np.floor(bottom + margin + 0.5)))
        right = min(image_w, int(np.floor(right + margin + 0.5)))
        return top, left, bottom, right

    def _get_font(self, size):
        """Return the label font at ``size``, loading each size from disk only once."""
        size = max(int(size), 1)
        cache = self.__dict__.setdefault('_font_cache', {})
        font = cache.get(size)
        if font is None:
            font = ImageFont.truetype(font=self._font_path, size=size)
            cache[size] = font
        return font

    @staticmethod
    def _text_size(draw, text, font):
        """Return ``(width, height)`` of ``text`` measured from the origin.

        ``ImageDraw.textsize`` was removed in Pillow 10; use ``textbbox``
        when it is available and fall back to ``textsize`` otherwise.
        """
        if hasattr(draw, 'textbbox'):
            bbox = draw.textbbox((0, 0), text, font=font)
            return bbox[2], bbox[3]
        return draw.textsize(text, font)

    def _publish_detections(self, image, result):
        """Publish an ``Object`` message for, and draw, every detection in ``result``.

        Parameters
        ----------
        image : PIL.Image.Image
            Original RGB frame; boxes and labels are drawn onto it in place.
        result : np.ndarray
            This frame's detections, rows
            ``[label, confidence, xmin, ymin, xmax, ymax]``.
        """
        image_shape = np.array(np.shape(image)[0:2])
        image_h, image_w = int(image_shape[0]), int(image_shape[1])
        labels, confidences, boxes = self._decode_boxes(result, image_shape)

        font = self._get_font(np.floor(3e-2 * image_w + 0.5))
        thickness = max((image_h + image_w) // self.input_shape[0], 1)
        draw = ImageDraw.Draw(image)
        for idx, c in enumerate(labels):
            # Labels are 1-based (see generate()).
            class_idx = int(c) - 1
            predicted_class = self.class_names[class_idx]
            score = confidences[idx]
            top, left, bottom, right = self._clip_box(boxes[idx], image_h, image_w, self._box_margin)

            msg = self.object_dectction_msg
            msg.label = predicted_class
            msg.confidence = float(score)
            msg.x = left
            msg.y = top
            msg.width = right - left
            msg.height = bottom - top
            self.object_result_pub.publish(msg)

            # Box outline (``thickness`` nested rectangles) plus a filled label
            # tag above the box, or just inside it when there is no room above.
            color = self.colors[class_idx]
            label = '{} {:.2f}'.format(predicted_class, score)
            label_w, label_h = self._text_size(draw, label, font)
            if top - label_h >= 0:
                text_origin = (left, top - label_h)
            else:
                text_origin = (left, top + 1)
            for k in range(thickness):
                draw.rectangle([left + k, top + k, right - k, bottom - k], outline=color)
            draw.rectangle([text_origin, (text_origin[0] + label_w, text_origin[1] + label_h)], fill=color)
            # Draw the text directly; str(label.encode(...)) showed "b'...'" on Python 3.
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw
def _main():
    """Start the detector node.

    SSD() runs its detection loop inside the constructor, so rospy.spin() is
    only reached once that loop exits; a ROS interrupt is logged.
    """
    try:
        SSD()
        rospy.spin()
    except rospy.ROSInterruptException:
        rospy.loginfo("Follower node terminated.")


if __name__ == '__main__':
    _main()
